
import torch
from transformers import BertForTokenClassification

class BERTModel:
    """Hold a ``BertForTokenClassification`` model loaded from a checkpoint.

    The model is created with ``from_pretrained`` and then moved to the
    requested device. The result is stored on ``self.model``.
    """

    def __init__(self, model_path, num_labels, device):
        """Load the token-classification model and place it on ``device``.

        Args:
            model_path: Checkpoint location forwarded to
                ``BertForTokenClassification.from_pretrained``.
            num_labels: Forwarded to ``from_pretrained`` as the
                ``num_labels`` keyword.
            device: Target passed to the model's ``.to(...)`` call.
        """
        loaded = BertForTokenClassification.from_pretrained(
            model_path, num_labels=num_labels
        )
        # Keep whatever `.to(device)` returns; that object is what callers get.
        self.model = loaded.to(device)

    def get_model(self):
        """Return the object stored in ``self.model``."""
        return self.model

